Torch 版常微分方程

🔖 python
Author

Guangyao Zhao

Published

Apr 3, 2023

导入必要的第三方库:

import numpy as np
import torch
import torch.nn as nn
from torchdiffeq import odeint
import matplotlib.pyplot as plt

定义常微分方程类:

device = torch.device("cpu")  # 定义计算为 cpu
data_size = 100  # samples in dataset


class glycolysis(nn.Module):
    def __init__(self):
        super(
            glycolysis, self
        ).__init__()  # 当我们定义一个子类时,如果这个子类继承自一个父类,那么在子类的构造函数中应该显式地调用父类的构造函数,以确保父类的属性和方法也能够被正确地初始化。
        self.J0 = 2.5  # mM min-1
        self.k1 = 100.0  # mM-1 min-1
        self.k2 = 6.0  # mM min-1
        self.k3 = 16.0  # mM min-1
        self.k4 = 100.0  # mM min-1
        self.k5 = 1.28  # mM min-1
        self.k6 = 12.0  # mM min-1
        self.k = 1.8  # min-1
        self.kappa = 13.0  # min-1
        self.q = 4.0
        self.K1 = 0.52  # mM
        self.psi = 0.1
        self.N = 1.0  # mM
        self.A = 4.0  # mM

    def forward(self, t, y):
        S1 = y.view(-1, 7)[:, 0]
        S2 = y.view(-1, 7)[:, 1]
        S3 = y.view(-1, 7)[:, 2]
        S4 = y.view(-1, 7)[:, 3]
        S5 = y.view(-1, 7)[:, 4]
        S6 = y.view(-1, 7)[:, 5]
        S7 = y.view(-1, 7)[:, 6]

        dS1 = self.J0 - (self.k1 * S1 * S6) / (1 + (S6 / self.K1) ** self.q)
        dS2 = (
            2.0 * (self.k1 * S1 * S6) / (1 + (S6 / self.K1) ** self.q)
            - self.k2 * S2 * (self.N - S5)
            - self.k6 * S2 * S5
        )
        dS3 = self.k2 * S2 * (self.N - S5) - self.k3 * S3 * (self.A - S6)
        dS4 = self.k3 * S3 * (self.A - S6) - self.k4 * S4 * S5 - self.kappa * (S4 - S7)
        dS5 = self.k2 * S2 * (self.N - S5) - self.k4 * S4 * S5 - self.k6 * S2 * S5
        dS6 = (
            -2.0 * (self.k1 * S1 * S6) / (1 + (S6 / self.K1) ** self.q)
            + 2.0 * self.k3 * S3 * (self.A - S6)
            - self.k5 * S6
        )
        dS7 = self.psi * self.kappa * (S4 - S7) - self.k * S7

        return torch.stack([dS1, dS2, dS3, dS4, dS5, dS6, dS7], dim=1).to(device)

初始化变量,进行计算:

# Initial condition, time span & parameters
y0 = torch.tensor([[1.6, 1.5, 0.2, 0.35, 0.3, 2.67, 0.1]]).to(
    device
)  #! initial condition
t = torch.linspace(0.0, 4.0, data_size).to(device)  #! saveat
p = torch.tensor(
    [2.5, 100.0, 6.0, 16.0, 100.0, 1.28, 12.0, 1.8, 13.0, 4.0, 0.52, 0.1, 1.0, 4.0]
).to(
    device
)  #! mechanistic model parameters

# Disable backprop, solve system of ODEs
with torch.no_grad():  # 所有的计算都将使用 float64 类型,而不是默认的 float32。也会减少内存消耗
    true_y = odeint(glycolysis(), y0, t, method="dopri5")

可视化结果:

# 绘图
fig = plt.figure()
axes = fig.add_subplot()
axes.plot(t, true_y[:, 0, 0])  # 第一维时间;第二维可以理解为初始状态的组数,因为只有一组初始状态,所以此数值为1,第三维是第几个变量